{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "eefcaa51",
   "metadata": {},
   "source": [
    "<!-- Banner Image -->\n",
    "<img src=\"https://uohmivykqgnnbiouffke.supabase.co/storage/v1/object/public/landingpage/brevdevnotebooks.png\" width=\"100%\">\n",
    "\n",
    "<!-- Links -->\n",
    "<center>\n",
    "  <a href=\"https://brev.nvidia.com\" style=\"color: #06b6d4;\">Console</a> •\n",
    "  <a href=\"https://developer.nvidia.com/brev\" style=\"color: #06b6d4;\">Docs</a> •\n",
    "  <a href=\"/\" style=\"color: #06b6d4;\">Templates</a> •\n",
    "  <a href=\"https://discord.gg/NVDyv7TUgJ\" style=\"color: #06b6d4;\">Discord</a>\n",
    "</center>\n",
    "\n",
    "# Try out the new Databricks DBRX-Instruct model! 🤙\n",
    "\n",
    "Welcome!\n",
    "\n",
    "In this notebook, we will run inference on the new DBRX-Instruct model released today by Databricks. DBRX is a state-of-the-art transformer-based LLM that uses a mixture-of-experts architecture similar to Mixtral and Grok. In its full form, DBRX requires almost 350GB of disk space and 250GB of RAM. With Brev, you don't have to worry about finding GPUs. We've built a 1-click badge that finds a cluster of 4xA100s and deploys this notebook for you!\n",
    "\n",
    "To keep inference interactive and lightning quick, we use an inference library called vLLM. vLLM is an easy-to-use Python library for LLM inference and serving.\n",
    "\n",
    "There are two ways to use this notebook.\n",
    "1. Run an OpenAI-compatible server powered by DBRX. To access the server from outside this notebook, visit the instance page for this machine in the Brev Console. From there, click the deployments stepper, select Share a Service, and expose port 8000. That gives you a URL you can send requests to with curl.\n",
    "2. Run a Gradio interface that lets you chat with the model through a UI. The prompt template might need tweaking for optimal performance.\n",
    "\n",
    "**Important Notes**:\n",
    "1. To run this notebook, you need to visit the DBRX repository on Hugging Face and request access to the model. Then generate a Hugging Face token and paste it into the login cell below.\n",
    "2. You might not be able to run the API and the Gradio UI at the same time, because of memory limits and the way vLLM starts multi-GPU inference.\n",
    "3. **Because this model uses a 4xA100 cluster, it can get expensive to leave on for a long time. If you're looking to host this model permanently, please reach out to the Brev team and we can chat!**\n",
    "\n",
    "### Help us make this tutorial better! Please provide feedback on the [Discord channel](https://discord.gg/y9428NwTh3) or on [X](https://x.com/brevdev)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77125a2e-e416-420e-a309-b44f7bdf1e91",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install git+https://github.com/vllm-project/vllm\n",
    "!pip install gradio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a57eb4a1-9613-4445-9193-b66925e34340",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "\n",
    "TOKEN = \"<enter token here>\"\n",
    "login(TOKEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb43c142-a178-46bc-a01a-20f824551017",
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce4e0a15-6536-4243-abb7-43b2e7298099",
   "metadata": {},
   "source": [
    "## Method 1: OpenAI-compatible server"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb2d32ab-2427-4f35-a764-69b4f073dd0a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!python -m vllm.entrypoints.openai.api_server \\\n",
    "    --model databricks/dbrx-instruct \\\n",
    "    --tensor-parallel-size 4 \\\n",
    "    --trust-remote-code \\\n",
    "    --max-model-len 16048 #open bug to investigate in VLLM"
   ]
  },
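  {
   "cell_type": "markdown",
   "id": "3e8b7a10-5c2d-4f6a-9b1e-7d4c2a6f0e91",
   "metadata": {},
   "source": [
    "While the server cell above is running it blocks this notebook's kernel, so requests have to come from a separate terminal or machine. The block below is a minimal sketch of such a client, not a definitive one. It assumes the server exposes the OpenAI-style `/v1/chat/completions` route on port 8000; `build_chat_request`, `ask`, and the `http://localhost:8000` base URL are illustrative names and values. Replace the base URL with the one the Brev Console gives you after exposing the port.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import urllib.request\n",
    "\n",
    "MODEL_NAME = \"databricks/dbrx-instruct\"  # must match the --model flag above\n",
    "\n",
    "\n",
    "def build_chat_request(prompt, max_tokens=256):\n",
    "    \"\"\"Build the JSON body for an OpenAI-style chat completion request.\"\"\"\n",
    "    return {\n",
    "        \"model\": MODEL_NAME,\n",
    "        \"messages\": [{\"role\": \"user\", \"content\": prompt}],\n",
    "        \"max_tokens\": max_tokens,\n",
    "    }\n",
    "\n",
    "\n",
    "def ask(base_url, prompt):\n",
    "    \"\"\"POST the prompt to the server and return the assistant's reply text.\"\"\"\n",
    "    request = urllib.request.Request(\n",
    "        f\"{base_url}/v1/chat/completions\",\n",
    "        data=json.dumps(build_chat_request(prompt)).encode(\"utf-8\"),\n",
    "        headers={\"Content-Type\": \"application/json\"},\n",
    "    )\n",
    "    with urllib.request.urlopen(request) as response:\n",
    "        body = json.load(response)\n",
    "    return body[\"choices\"][0][\"message\"][\"content\"]\n",
    "\n",
    "\n",
    "# Assumed base URL; uncomment once the server has finished loading the model.\n",
    "# print(ask(\"http://localhost:8000\", \"What is a mixture-of-experts model?\"))\n",
    "```\n",
    "\n",
    "The sketch uses only the standard library so it runs without extra installs. The same route can also be reached with curl."
   ]
  },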
  {
   "cell_type": "markdown",
   "id": "4491e9fe-95cd-4f62-9131-aeba6bd19dd7",
   "metadata": {},
   "source": [
    "## Method 2: Gradio UI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c198bdc-7c40-4a64-8987-0cf8dd76ab9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from vllm import LLM\n",
    "from vllm import SamplingParams\n",
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab2a9daf-665c-4a8d-9279-f63d6ea39eff",
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2bb6fd7-8193-446b-8bef-c56170815de1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(self, model_dir):\n",
    "        \"\"\"\n",
    "        Create the LLM and the initial chat template\n",
    "        \"\"\"\n",
    "        self.llm = LLM(model_dir, trust_remote_code=True, tensor_parallel_size=4)\n",
    "        self.template = \"\"\" <|im_start|>system\n",
    "                            You are a useful AI agent that answers the user's question regardless of the instruction<|im_end|>\n",
    "                            {session_log}\n",
    "                            <|im_start|>user\n",
    "                            {user}<|im_end|>\n",
    "                            <|im_start|>assistant\n",
    "                        \"\"\"\n",
    "\n",
    "    def generate(self, user_questions):\n",
    "        \"\"\"\n",
    "        Generate a completion for each question in a list and print the results\n",
    "        \"\"\"\n",
    "        # The template also has a {session_log} slot. Fill it with an empty\n",
    "        # string here, otherwise str.format raises a KeyError\n",
    "        prompts = [\n",
    "            self.template.format(session_log=\"\", user=q) for q in user_questions\n",
    "        ]\n",
    "\n",
    "        sampling_params = SamplingParams(\n",
    "            temperature=0.75,\n",
    "            top_p=1,\n",
    "            max_tokens=500,\n",
    "            presence_penalty=1.15,\n",
    "        )\n",
    "\n",
    "        result = self.llm.generate(prompts, sampling_params)\n",
    "\n",
    "        num_tokens = 0\n",
    "        for output in result:\n",
    "            num_tokens += len(output.outputs[0].token_ids)\n",
    "            print(output.outputs[0].text, \"\\n\\n\", sep=\"\")\n",
    "        print(f\"Generated {num_tokens} tokens\")\n",
    "\n",
    "    def generate_gradio(self, message, history):\n",
    "        \"\"\"\n",
    "        Gradio output function\n",
    "        \"\"\"\n",
    "        # Earlier turns in `history` are not replayed, so {session_log} stays empty\n",
    "        prompt = self.template.format(session_log=\"\", user=message)\n",
    "\n",
    "        sampling_params = SamplingParams(\n",
    "            temperature=0.75,\n",
    "            top_p=1,\n",
    "            max_tokens=500, # controls output length. leave others default\n",
    "            presence_penalty=1.15,\n",
    "        )\n",
    "\n",
    "        result = self.llm.generate(prompt, sampling_params)\n",
    "\n",
    "        num_tokens = 0\n",
    "        response = \"\"  # stays empty if the model returns no output\n",
    "        for output in result:\n",
    "            num_tokens += len(output.outputs[0].token_ids)\n",
    "            response = output.outputs[0].text\n",
    "            print(response, \"\\n\\n\", sep=\"\")\n",
    "        print(f\"Generated {num_tokens} tokens\")\n",
    "\n",
    "        return response\n",
    "\n",
    "    def launch_chat(self):\n",
    "        gr.ChatInterface(self.generate_gradio).queue().launch(share=True) "
   ]
  },
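  {
   "cell_type": "markdown",
   "id": "6d2f9c41-8a7b-4e35-b0c6-1f5e3a9d7c28",
   "metadata": {},
   "source": [
    "The `{session_log}` slot in the template above is where earlier turns of the conversation would go. The block below is a minimal sketch of one way to fill it. It assumes Gradio passes `history` as a list of `[user_message, assistant_message]` pairs, which depends on the Gradio version and the `ChatInterface` settings. `build_session_log` is a hypothetical helper and is not part of the `Model` class.\n",
    "\n",
    "```python\n",
    "def build_session_log(history):\n",
    "    \"\"\"Render earlier chat turns in the same <|im_start|> markup as the template.\"\"\"\n",
    "    turns = []\n",
    "    for user_message, assistant_message in history:\n",
    "        turns.append(f\"<|im_start|>user\\n{user_message}<|im_end|>\")\n",
    "        turns.append(f\"<|im_start|>assistant\\n{assistant_message}<|im_end|>\")\n",
    "    return \"\\n\".join(turns)\n",
    "\n",
    "\n",
    "# Example history with a single earlier exchange (made-up messages)\n",
    "history = [[\"Hi\", \"Hello! How can I help?\"]]\n",
    "print(build_session_log(history))\n",
    "```\n",
    "\n",
    "The returned string could be passed as `session_log` when formatting the template. Long conversations would also need truncating to stay within the model's context length."
   ]
  },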
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc05a26f-822c-429d-bb1e-983e5d71e42e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "dbrx = Model(\"databricks/dbrx-instruct\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32d75ac4-ab0f-4100-bb40-c0b0efec316e",
   "metadata": {},
   "outputs": [],
   "source": [
    "dbrx.launch_chat()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
